import os
import time
from unittest.mock import patch

import pytest

from zenml import pipeline, step
from zenml.config import CachePolicy
from zenml.constants import ENV_ZENML_PREVENT_CLIENT_SIDE_CACHING
from zenml.exceptions import EntityExistsError


@step(enable_cache=False)
def constant_int_output_test_step() -> int:
    """Return a fixed integer so the step always produces the same output."""
    answer = 42
    return answer


def test_pipeline_run_returns_up_to_date_run_info():
    """Check that the object returned by a pipeline call reflects the run.

    The returned run must already list the executed step, and that step must
    be reported as completed.
    """
    step_name = "constant_int_output_test_step"

    @pipeline
    def _pipeline():
        constant_int_output_test_step()

    run = _pipeline()
    step_runs = run.steps

    assert step_name in step_runs
    assert step_runs[step_name].status == "completed"


@step
def noop() -> None:
    """Do nothing; used as a lightweight step in the caching tests below."""


@step(enable_cache=False)
def simple_step_for_duplicate_test() -> int:
    """Return a constant; the duplicate-run-name test only needs a step."""
    result = 42
    return result


def test_pipeline_run_computes_clientside_cache(clean_client, mocker):
    """Check that cached steps are resolved client-side before submission.

    A first run populates the cache. On the second run only ``step_3``
    (caching disabled) and ``step_4`` (which runs after it) may be forwarded
    to the orchestrator; the cache-enabled ``step_1`` and ``step_2`` must be
    stripped from the submitted snapshot.
    """
    cached_noop = noop.with_options(enable_cache=True)
    uncached_noop = noop.with_options(enable_cache=False)

    @pipeline
    def partial_cached_pipeline():
        cached_noop(id="step_1")
        cached_noop(id="step_2", after="step_1")
        uncached_noop(id="step_3", after="step_1")
        cached_noop(id="step_4", after="step_3")

    # Warm-up run so the cache-enabled steps have entries to hit.
    partial_cached_pipeline()

    orchestrator = clean_client.active_stack.orchestrator
    submit_mock = mocker.patch.object(orchestrator, "submit_pipeline")

    partial_cached_pipeline()

    assert submit_mock.call_count == 1
    submitted_snapshot = submit_mock.call_args.kwargs["snapshot"]
    assert set(submitted_snapshot.step_configurations) == {"step_3", "step_4"}


def test_fully_cached_pipeline_doesnt_call_orchestrator_implementation(
    clean_client, mocker
):
    """Tests that a fully cached pipeline is never handed to the orchestrator.

    The first run populates the cache for both steps. On the second run every
    step should be resolved client-side. In that case neither the submission
    entry point ``submit_pipeline``, which the other caching tests in this
    module mock, nor the ``prepare_or_run_pipeline`` hook may be invoked.
    """
    step_with_cache_enabled = noop.with_options(enable_cache=True)

    @pipeline
    def full_cached_pipeline():
        step_with_cache_enabled(id="step_1")
        step_with_cache_enabled(id="step_2", after="step_1")

    # Populate the cache for both steps.
    full_cached_pipeline()

    orchestrator = clean_client.active_stack.orchestrator
    # Mock both hooks. Checking only ``prepare_or_run_pipeline`` would pass
    # vacuously for an orchestrator that is driven through
    # ``submit_pipeline`` and never calls that hook.
    mock_submit_pipeline = mocker.patch.object(orchestrator, "submit_pipeline")
    mock_prepare_or_run_pipeline = mocker.patch.object(
        orchestrator, "prepare_or_run_pipeline"
    )

    full_cached_pipeline()

    mock_submit_pipeline.assert_not_called()
    mock_prepare_or_run_pipeline.assert_not_called()


def test_environment_variable_can_be_used_to_disable_clientside_caching(
    clean_client, mocker
):
    """Check that the prevent-client-side-caching env var forces submission.

    Both steps are cached by a first run. With the environment variable set,
    the second run must still be submitted to the orchestrator instead of
    being resolved on the client.
    """
    cached_noop = noop.with_options(enable_cache=True)

    @pipeline
    def full_cached_pipeline():
        cached_noop(id="step_1")
        cached_noop(id="step_2", after="step_1")

    # Populate the cache for both steps.
    full_cached_pipeline()

    orchestrator = clean_client.active_stack.orchestrator
    submit_mock = mocker.patch.object(orchestrator, "submit_pipeline")

    # Set the variable only for the duration of the second run.
    env_override = {ENV_ZENML_PREVENT_CLIENT_SIDE_CACHING: "True"}
    with patch.dict(os.environ, env_override):
        full_cached_pipeline()

    submit_mock.assert_called()


def test_duplicate_pipeline_run_name_raises_improved_error(clean_client):
    """Test that reusing a pipeline run name raises a clear error.

    The first run with a fixed ``run_name`` must succeed and keep that name.
    A second run with the same name must raise ``EntityExistsError``, and its
    message must explain that the name is already taken.
    """

    @pipeline
    def test_pipeline():
        simple_step_for_duplicate_test()

    run_name = "duplicate_name_test_run"

    # First run should succeed and keep the requested name.
    first_run = test_pipeline.with_options(run_name=run_name)()
    assert first_run.name == run_name

    # Second run with the same name must be rejected.
    with pytest.raises(EntityExistsError) as exc_info:
        test_pipeline.with_options(run_name=run_name)()

    # Case-insensitive match on either accepted phrasing. Any message of the
    # form "Pipeline run name '...' already exists" is already covered by
    # the "already exists" check.
    error_message = str(exc_info.value).lower()
    assert (
        "already exists" in error_message
        or "existing pipeline run with the same name" in error_message
    )


@step
def cacheable_step() -> int:
    """Return a constant from a step that keeps the default cache settings."""
    result = 0
    return result


def test_cache_expiration(clean_client):
    """Test that cached step runs stop being reused once they expire.

    With a 1 second expiry, a run started after the window has passed must
    execute the step again. With a 1 hour expiry, an immediate re-run must
    reuse the cached result.
    """
    short_expiry_seconds = 1

    @pipeline(cache_policy=CachePolicy(expires_after=short_expiry_seconds))
    def test_pipeline():
        cacheable_step()

    test_pipeline()
    # Sleep strictly longer than the expiry window. Sleeping exactly
    # ``expires_after`` makes the result depend on how the boundary is
    # compared and on stored-timestamp precision, which can make this flaky.
    time.sleep(short_expiry_seconds + 0.5)
    run = test_pipeline()
    assert run.steps["cacheable_step"].status == "completed"

    @pipeline(cache_policy=CachePolicy(expires_after=3600))
    def test_pipeline():
        cacheable_step()

    # It does not matter whether this run hits a cache entry from the runs
    # above; it guarantees a fresh, unexpired entry for the next run.
    test_pipeline()
    run = test_pipeline()
    assert run.steps["cacheable_step"].status == "cached"
